Text classification using recurrent neural networks

This example shows how to use recurrent neural networks (with and without attention) to classify documents.

We use our usual sentiment analysis benchmark.

In [1]:
import torch
from torch import nn
import time
import torchtext

import random

from collections import defaultdict

import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina' 
plt.style.use('seaborn')

A text classifier based on RNNs

Let's define an RNN-based text classifier. We'll apply a bidirectional RNN and then base the classification on the last state in both directions.
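
Concretely, if $\overrightarrow{h}_T$ denotes the final state of the forward direction (after the last token) and $\overleftarrow{h}_1$ the final state of the backward direction (after the first token), the classifier computes the output scores as a linear function of their concatenation. This is a sketch of what the forward method below implements, with $W$ and $b$ the parameters of the top linear layer:

$$\mathrm{scores} = W\,[\overrightarrow{h}_T \,;\, \overleftarrow{h}_1] + b$$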

We'll optionally use pre-trained embeddings, which are assumed to be stored with the torchtext vocabulary object.
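
To make this concrete, here is a minimal sketch of how the vectors end up in the vocabulary object; the same build_vocab call appears in main() further down, and TEXT and train are assumed to be defined as in that function:

In [ ]:
# Building the vocabulary with a `vectors` argument loads the pre-trained
# GloVe embeddings and stores them in TEXT.vocab.vectors, a tensor of
# shape (voc_size, emb_dim).
TEXT.build_vocab(train, vectors="glove.6B.100d")
print(TEXT.vocab.vectors.shape)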

In [2]:
class RNNTextClassifier(nn.Module):
    
    def __init__(self, text_field, class_field, emb_dim, rnn_size, update_pretrained=False):
        super().__init__()        
        
        voc_size = len(text_field.vocab)
        n_classes = len(class_field.vocab)       
        
        # Embedding layer.
        self.embedding = nn.Embedding(voc_size, emb_dim)

        # If we're using pre-trained embeddings, copy them into the model's embedding layer.
        if text_field.vocab.vectors is not None:
            self.embedding.weight = torch.nn.Parameter(text_field.vocab.vectors, 
                                                       requires_grad=update_pretrained)
        
        # The RNN module: either a basic RNN, LSTM, or a GRU.
        #self.rnn = nn.RNN(input_size=emb_dim, hidden_size=rnn_size, 
        #                  bidirectional=True, num_layers=1)        
        #self.rnn = nn.LSTM(input_size=emb_dim, hidden_size=rnn_size, 
        #                   bidirectional=True, num_layers=1)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=rnn_size, 
                          bidirectional=True, num_layers=1)

        # And finally, a linear layer on top of the RNN layer to produce the output.
        self.top_layer = nn.Linear(2*rnn_size, n_classes)
        
    def forward(self, texts):
        # The words in the documents are encoded as integers. The shape of the documents
        # tensor is (max_len, n_docs), where n_docs is the number of documents in this batch,
        # and max_len is the maximal length of a document in the batch.

        # First look up the embeddings for all the words in the documents.
        # The shape is now (max_len, n_docs, emb_dim).
        embedded = self.embedding(texts)
        
        # The RNNs return two tensors: one representing the outputs at all positions
        # of the final layer, and another representing the final states of each layer.
        # In this example, we'll use just the final states.
        # NB: for a bidirectional RNN, the final state corresponds to the *last* token
        # in the forward direction and the *first* token in the backward direction.
        rnn_out, final_state = self.rnn(embedded)
        
        # The shape of final_state is (2*n_layers, n_docs, rnn_size), assuming that 
        # the RNN is bidirectional.
        # We select the top layer's forward and backward states and concatenate them.
        top_forward = final_state[-2]
        top_backward = final_state[-1]
        top_both = torch.cat([top_forward, top_backward], dim=1)
        
        # Apply the linear layer and return the output.
        return self.top_layer(top_both)

Adding an attention model

We now add an attention model, which will compute a weighted average of all the state vectors. These weights are based on an "importance" score computed by a neural network.
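
In formulas, this is a sketch of what the code below implements: with $h_i$ the RNN output at position $i$, and $w$, $b$ the parameters of the scoring network,

$$e_i = w^\top h_i + b, \qquad \alpha_i = \frac{\exp(e_i)}{\sum_j \exp(e_j)}, \qquad c = \sum_i \alpha_i h_i$$

The weighted sum $c$ is then fed to the linear output layer, in place of the concatenated final states we used before.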

We first define the attention model, and then the text classifier that uses it. We describe the attention model in detail, while the classification model should be fairly self-explanatory.

In [3]:
class SimpleAttention(nn.Module):
    def __init__(self, rnn_size):
        super().__init__()

        # This is the neural network that computes the attention scores.
        # To keep things simple, we'll use a linear model here.
        self.attn_nn = nn.Linear(rnn_size, 1)
        
    def forward(self, rnn_output):
        # The input to the attention model is the output from the top layer of the RNN,
        # which is a tensor containing the states for each position in each document.
        # The shape of this tensor is (n_words, n_docs, rnn_dim).
        
        # First, we apply the attention neural network to each state in the RNN output.
        e = self.attn_nn(rnn_output)
        
        # The shape is now (n_words, n_docs, 1). Squeezing the last dimension
        # reshapes the tensor to (n_words, n_docs). We squeeze dim=2 explicitly
        # so that a batch containing a single document is not squeezed as well.
        e = e.squeeze(dim=2)
        
        # Compute attention weights by applying a softmax over the word positions
        # (dim 0), so that the weights for each document sum to one.
        # This tensor has the same shape as e.
        alpha = torch.softmax(e, dim=0)
                
        # We weigh each RNN state by its attention weight.
        # In order to carry out the element-wise multiplication, we need to "flip"
        # the tensor so that the RNN state dimension comes first.
        # This tensor has the shape (rnn_dim, n_words, n_docs).
        weighted = alpha * rnn_output.permute(2, 0, 1)
        
        # Compute a weighted sum of the RNN state vectors. We sum over the word dimension.
        # The shape is now (rnn_dim, n_docs).
        out = weighted.sum(dim=1)
        
        # "Flip" the tensor back to the shape (n_docs, rnn_dim) so that it fits
        # with the linear layer in the text classifier.
        return out.t()
        
        
class RNNAttentionTextClassifier(nn.Module):
    
    def __init__(self, text_field, class_field, emb_dim, rnn_size, update_pretrained=False):
        super().__init__()        
                
        voc_size = len(text_field.vocab)
        n_classes = len(class_field.vocab)       
        
        self.embedding = nn.Embedding(voc_size, emb_dim)
        if text_field.vocab.vectors is not None:
            self.embedding.weight = torch.nn.Parameter(text_field.vocab.vectors, 
                                                       requires_grad=update_pretrained)
        
        #self.rnn = nn.RNN(input_size=emb_dim, hidden_size=rnn_size, 
        #                  bidirectional=True, num_layers=1)
        #self.rnn = nn.LSTM(input_size=emb_dim, hidden_size=rnn_size, 
        #                   bidirectional=True, num_layers=1)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=rnn_size, 
                          bidirectional=True, num_layers=1)

        self.attention = SimpleAttention(2*rnn_size)        
        self.top_layer = nn.Linear(2*rnn_size, n_classes)
        
    def forward(self, texts):
        embedded = self.embedding(texts)
        rnn_out, final_state = self.rnn(embedded)
        
        # The attention model returns the weighted sum of RNN states for each document.
        # The shape is (n_docs, 2*rnn_size).
        attention_out = self.attention(rnn_out)

        # Apply the linear layer and return the output.
        return self.top_layer(attention_out)
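
As a quick sanity check, we can run the attention module on a random tensor shaped like an RNN output and verify that it returns one summary vector per document. This is just a sketch; the sizes (35 words, 16 documents, rnn_size 64) are arbitrary.

In [ ]:
# Shape check for SimpleAttention on random data (arbitrary sizes).
attn = SimpleAttention(2*64)               # 2*rnn_size, since the RNN is bidirectional
dummy_rnn_out = torch.randn(35, 16, 2*64)  # (n_words, n_docs, 2*rnn_size)
print(attn(dummy_rnn_out).shape)           # expected: torch.Size([16, 128])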

Training the text classifier

We train the classifier and evaluate on the validation set. This code is almost identical to the code that we saw in the CNN lecture.

For the first RNN-based classifier, the performance tends to be a bit lower than for the CNN from Lecture 2. When we add attention, the performance is usually slightly better than for the other models, peaking at about 0.86-0.87 on the validation set. However, the performance of both models is a bit "jumpy" and can vary between runs.

A note on pre-trained word embeddings: we now use the glove.6B.100d model, which torchtext can download automatically. The first time you run this code, the GloVe vectors will be downloaded, which takes some time; subsequent runs reuse the cached copy. To use the pre-trained embeddings, they need to be copied into the neural network's parameters (see above).

In [4]:
def read_data(corpus_file, datafields, label_column, doc_start):
    with open(corpus_file, encoding='utf-8') as f:
        examples = []
        for line in f:
            columns = line.strip().split(maxsplit=doc_start)
            doc = columns[-1]
            label = columns[label_column]
            examples.append(torchtext.data.Example.fromlist([doc, label], datafields))
    return torchtext.data.Dataset(examples, datafields)

def evaluate_validation(scores, loss_function, gold):
    guesses = scores.argmax(dim=1)
    n_correct = (guesses == gold).sum().item()
    return n_correct, loss_function(scores, gold).item()

def main():
   
    TEXT = torchtext.data.Field(sequential=True, tokenize=lambda x: x.split())
    LABEL = torchtext.data.LabelField(is_target=True)
    datafields = [('text', TEXT), ('label', LABEL)]
    
    random.seed(0)
    
    data = read_data('data/all_sentiment_shuffled.txt', datafields, label_column=1, doc_start=3)
    train, valid = data.split([0.8, 0.2])

    use_pretrained = True

    if use_pretrained:
        print('We are using pre-trained word embeddings.')
        TEXT.build_vocab(train, vectors="glove.6B.100d")
    else:        
        print('We are training word embeddings from scratch.')
        TEXT.build_vocab(train, max_size=10000)
    LABEL.build_vocab(train)
        
    # Declare the RNN classifier.
    #model = RNNTextClassifier(TEXT, LABEL, emb_dim=100, rnn_size=64, update_pretrained=True)
    model = RNNAttentionTextClassifier(TEXT, LABEL, emb_dim=100, rnn_size=64, update_pretrained=True)
    
    # Run on a GPU if one is available, otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)    
    
    train_iterator = torchtext.data.BucketIterator(
        train,
        device=device,
        batch_size=128,
        sort_key=lambda x: len(x.text),
        repeat=False,
        train=True,
        sort=True)
    
    valid_iterator = torchtext.data.BucketIterator(
        valid,
        device=device,
        batch_size=128,
        sort_key=lambda x: len(x.text),
        repeat=False,
        train=False,
        sort=True)
    
    loss_function = torch.nn.CrossEntropyLoss()    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0025, weight_decay=1e-4)    
    
    train_batches = list(train_iterator)
    valid_batches = list(valid_iterator)
    
    history = defaultdict(list)

    for i in range(25):
        
        t0 = time.time()
        
        loss_sum = 0
        n_batches = 0

        model.train()
        
        for batch in train_batches:
                        
            scores = model(batch.text)
            loss = loss_function(scores, batch.label)

            optimizer.zero_grad()            
            loss.backward()
            optimizer.step()
    
            loss_sum += loss.item()
            n_batches += 1
        
        train_loss = loss_sum / n_batches
        history['train_loss'].append(train_loss)
        
        n_correct = 0
        n_valid = len(valid)
        loss_sum = 0
        n_batches = 0

        model.eval()
        
        for batch in valid_batches:
            scores = model(batch.text)
            n_corr_batch, loss_batch = evaluate_validation(scores, loss_function, batch.label)
            loss_sum += loss_batch
            n_correct += n_corr_batch
            n_batches += 1
        val_acc = n_correct / n_valid
        val_loss = loss_sum / n_batches

        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)        
        
        t1 = time.time()
        print(f'Epoch {i+1}: train loss = {train_loss:.4f}, val loss = {val_loss:.4f}, val acc: {val_acc:.4f}, time = {t1-t0:.4f}')

    plt.plot(history['train_loss'])
    plt.plot(history['val_loss'])
    plt.plot(history['val_acc'])
    plt.legend(['training loss', 'validation loss', 'validation accuracy'])
    
main()
We are using pre-trained word embeddings.
Epoch 1: train loss = 0.5580, val loss = 0.6053, val acc: 0.7067, time = 1.8969
Epoch 2: train loss = 0.3726, val loss = 0.3975, val acc: 0.8334, time = 1.8331
Epoch 3: train loss = 0.2426, val loss = 0.3219, val acc: 0.8720, time = 1.7375
Epoch 4: train loss = 0.1202, val loss = 0.6293, val acc: 0.8099, time = 1.8366
Epoch 5: train loss = 0.1230, val loss = 0.3535, val acc: 0.8666, time = 1.7783
Epoch 6: train loss = 0.0762, val loss = 0.8362, val acc: 0.7944, time = 1.9434
Epoch 7: train loss = 0.0546, val loss = 0.4509, val acc: 0.8624, time = 1.7927
Epoch 8: train loss = 0.0518, val loss = 0.6143, val acc: 0.8267, time = 1.8181
Epoch 9: train loss = 0.0526, val loss = 0.4963, val acc: 0.8649, time = 1.8636
Epoch 10: train loss = 0.0125, val loss = 0.5593, val acc: 0.8649, time = 1.7718
Epoch 11: train loss = 0.0046, val loss = 0.7040, val acc: 0.8523, time = 1.8620
Epoch 12: train loss = 0.0037, val loss = 0.6599, val acc: 0.8628, time = 1.8516
Epoch 13: train loss = 0.0060, val loss = 0.7189, val acc: 0.8556, time = 1.9215
Epoch 14: train loss = 0.0273, val loss = 0.5430, val acc: 0.8510, time = 1.7886
Epoch 15: train loss = 0.0559, val loss = 0.5782, val acc: 0.8573, time = 1.8280
Epoch 16: train loss = 0.0271, val loss = 0.6924, val acc: 0.8519, time = 1.8917
Epoch 17: train loss = 0.0089, val loss = 0.8317, val acc: 0.8523, time = 1.9212
Epoch 18: train loss = 0.0049, val loss = 0.8411, val acc: 0.8561, time = 1.9096
Epoch 19: train loss = 0.0091, val loss = 0.7192, val acc: 0.8523, time = 1.8796
Epoch 20: train loss = 0.0268, val loss = 0.6000, val acc: 0.8611, time = 1.7498
Epoch 21: train loss = 0.0230, val loss = 0.6211, val acc: 0.8611, time = 1.8845
Epoch 22: train loss = 0.0358, val loss = 0.6774, val acc: 0.8506, time = 1.7435
Epoch 23: train loss = 0.0329, val loss = 0.9656, val acc: 0.8200, time = 1.8652
Epoch 24: train loss = 0.0169, val loss = 0.8439, val acc: 0.8443, time = 1.8609
Epoch 25: train loss = 0.0170, val loss = 0.6603, val acc: 0.8552, time = 1.8016